# import libraries
from pyspark.sql import SparkSession
from pyspark import SparkConf
from pyspark.sql.types import *

from pyspark.sql.functions import col, count, lit, rand, when

import pandas as pd
from math import ceil

#################################################
# spark config
#################################################
# Address of the Spark master this job is submitted to.
mtaMaster = "spark://192.168.0.182:7077"

conf = SparkConf()
conf.setMaster(mtaMaster)

# Every remaining setting as a (key, value) pair; they are applied in the
# order listed, which is the order the settings were originally written in.
_spark_settings = (
    # memory and cores for executors and driver
    ("spark.executor.memory", "24g"),
    ("spark.driver.memory", "26g"),
    ("spark.cores.max", 96),
    ("spark.driver.cores", 8),
    # Kryo serializer with a 256m buffer (initial and maximum)
    ("spark.serializer", "org.apache.spark.serializer.KryoSerializer"),
    ("spark.kryoserializer.buffer", "256m"),
    ("spark.kryoserializer.buffer.max", "256m"),
    # default number of partitions
    ("spark.default.parallelism", 24),
    # event logging to HDFS
    ("spark.eventLog.enabled", "true"),
    ("spark.eventLog.dir", "hdfs://192.168.0.182:9000/eventlog"),
    ("spark.history.fs.logDirectory", "hdfs://192.168.0.182:9000/eventlog"),
    # upper bound on the size of results collected to the driver
    ("spark.driver.maxResultSize", "4g"),
)
for _key, _value in _spark_settings:
    conf.set(_key, _value)

# The returned list of settings is not used; the call is kept as it was.
conf.getAll()

#################################################
# create spark session
#################################################
# Build (or reuse) the session from the configuration assembled above.
_session_builder = SparkSession.builder.appName('ML2_HV_v1_NYT_sim1_and_sim3_to_sim2_round1_human_validation')
spark = _session_builder.config(conf=conf).getOrCreate()

sc = spark.sparkContext

# Sanity check: show the context and its default parallelism before any work starts.
print(sc)
print(sc.defaultParallelism)
print("SPARK CONTEXT IS RUNNING")

#################################################
# define major topic codes
#################################################

# Major topic codes iterated over by the loops below: 1-10, 12-21 and 100.
# Code 23 is left out because it does not occur in the NYT corpus; for a
# corpus that contains it, insert 23 between 21 and 100.
majortopic_codes = [*range(1, 11), *range(12, 22), 100]

#################################################
# read result data from round 1
#################################################

# Load the round-1 classification output and spread it over 50 partitions.
_results_path = "hdfs://192.168.0.182:9000/input/ML2_HV_v1_NYT_r1_classified.parquet"
df_results = spark.read.parquet(_results_path).repartition(50)

# Cast verdict to an integer column so it can be compared with majortopic later on.
df_results = df_results.withColumn('verdict', col('verdict').cast(IntegerType()))

#################################################
# create table to store sample and validation numbers
#################################################

# One row per major topic code, one column per statistic filled in further down.
columns = ["num_classified", "num_sample", "num_non_sample", "num_correct", "num_incorrect", "precision_in_sample", "num_added_to_training"]
# Fill the table with integer zeros at construction time. Creating an empty
# (all-NaN, object dtype) frame and calling fillna(0) relies on pandas silently
# downcasting object columns, which recent pandas versions deprecate, so the
# resulting dtypes would depend on the installed version.
df_numbers = pd.DataFrame(0, index=majortopic_codes, columns=columns)
# precision_in_sample holds a ratio, so give it a float dtype up front; writing
# a float into an integer column is itself deprecated in recent pandas.
df_numbers["precision_in_sample"] = df_numbers["precision_in_sample"].astype(float)

#################################################
# create table of samples from results
#################################################

# Constants for the sample-size calculation: 95% confidence (z = 1.96) with a
# +-0.05 confidence interval on precision (delta).
# z_delta = z^2 * 0.5 * 0.5 / delta^2; the two 0.5 factors are presumably
# p and (1 - p) at p = 0.5 -- TODO confirm.
z = 1.96
delta = 0.05
_half = 0.5
z_delta = z * z * _half * _half / (delta * delta)
print("z_delta :", z_delta)

# Draw a validation sample for every major topic code and collect, across all
# codes, the sampled rows (df_sample_all) and the rows left out of the sample
# (df_non_sample_all). Per-code counts are written to df_numbers.
#
# Both accumulators start as None so that the first code processed seeds them,
# whichever code that is; df_non_sample_all stays None when no code has more
# than 100 classified rows.
df_sample_all = None
df_non_sample_all = None

for i in majortopic_codes:
    df_classified = df_results.where(col('verdict') == i)
    num_classified = df_classified.count()
    # Single .loc[row, column] assignment: the chained form
    # df_numbers[column].loc[row] = ... writes to an intermediate Series and is
    # not guaranteed to reach df_numbers (it has no effect under copy-on-write).
    df_numbers.loc[i, "num_classified"] = num_classified
    print("MTC:", i, "num_classified: ", num_classified)

    if num_classified > 100:
        # sample size from z_delta, corrected for the number of classified rows
        sample_size = ceil(z_delta/(1+1/num_classified*(z_delta-1)))
        print("sample_size: ", sample_size)
        # never sample fewer than 100 rows
        sample_size = max(sample_size, 100)

        # Random draw: order by doc_id, attach a random key, keep the rows with
        # the smallest keys.
        # NOTE(review): df_sample is lazy and is re-evaluated by every action
        # that touches it (the count below, the anti-join, the later unions and
        # writes). The sort on doc_id before rand() looks intended to keep the
        # draw reproducible across those evaluations -- confirm, or persist it.
        df_sample = (
            df_classified.sort('doc_id')
            .withColumn('random', rand())
            .sort('random')
            .limit(sample_size)
            .drop('random')
        )
        df_sample_num = df_sample.count()
        print("df_sample: ", df_sample_num)

        # separate non-sample from sample elements
        ids_drop = df_sample.select("doc_id")
        df_non_sample = df_classified.join(ids_drop, "doc_id", "left_anti")
        df_numbers.loc[i, "num_sample"] = df_sample_num
        df_numbers.loc[i, "num_non_sample"] = df_non_sample.count()
    else:
        # 100 rows or fewer: every classified row goes into the sample
        df_numbers.loc[i, "num_sample"] = num_classified
        df_sample = df_classified
        df_non_sample = None

    # add the new sample to the table of all samples
    if df_sample_all is None:
        df_sample_all = df_sample
    else:
        df_sample_all = df_sample_all.union(df_sample)

    # add the new non-sample rows, if any, to the table of all non-samples
    if df_non_sample is not None:
        if df_non_sample_all is None:
            df_non_sample_all = df_non_sample
        else:
            df_non_sample_all = df_non_sample_all.union(df_non_sample)
    print("MTC:", i)


#################################################
# check precision by majortopic codes
#################################################

# Split the pooled sample into rows whose verdict agrees with majortopic and
# rows whose verdict does not, then record per-code counts and precision.
df_correctly_classified = df_sample_all.where(col('majortopic') == col('verdict'))
# Every sampled row has a non-NULL verdict (it was selected by verdict == code).
# A plain != evaluates to NULL when majortopic is NULL, so such a row would
# match neither filter and drop out of the data written for round 2; count it
# as incorrect instead.
df_incorrectly_classified = df_sample_all.where(
    (col('majortopic') != col('verdict')) | col('majortopic').isNull()
)

# One grouped count per table instead of one Spark job per code and table.
correct_counts = {
    row['verdict']: row['count']
    for row in df_correctly_classified.groupBy('verdict').count().collect()
}
incorrect_counts = {
    row['verdict']: row['count']
    for row in df_incorrectly_classified.groupBy('verdict').count().collect()
}

# The precision column receives ratios; make sure it is float before writing.
df_numbers["precision_in_sample"] = df_numbers["precision_in_sample"].astype(float)

for i in majortopic_codes:
    num_correct = correct_counts.get(i, 0)
    num_sample = df_numbers.loc[i, "num_sample"]
    # .loc[row, column] writes straight into df_numbers (no chained assignment)
    df_numbers.loc[i, "num_correct"] = num_correct
    # a code with an empty sample has no defined precision
    df_numbers.loc[i, "precision_in_sample"] = (
        num_correct / num_sample if num_sample else float("nan")
    )
    # incorrectly classified count, kept for debugging and checking
    df_numbers.loc[i, "num_incorrect"] = incorrect_counts.get(i, 0)

df_numbers['num_added_to_training'] = df_numbers['num_correct']

print(df_numbers)


#################################################
# add only validated positives to training set
#################################################

# There may be no non-sample rows at all. In that case replace the None with
# the string marker "empty": it is an ordinary Python value, unaffected by
# clearing the Spark cache, and it is tested again later to decide whether a
# non-sample table was written.
if df_non_sample_all is None:
    df_non_sample_all = "empty"

# Write all tables to parquet before clearing memory.
_temp_tables = [
    (df_correctly_classified, "hdfs://192.168.0.182:9000/input/df_correct_replace_temp.parquet"),
    (df_incorrectly_classified, "hdfs://192.168.0.182:9000/input/df_wrong_replace_temp.parquet"),
]
for _table, _target in _temp_tables:
    _table.write.parquet(_target, mode="overwrite")

# The non-sample table is written only when it exists.
if df_non_sample_all != "empty":
    df_non_sample_all.write.parquet("hdfs://192.168.0.182:9000/input/df_non_sample_replace_temp.parquet", mode="overwrite")

# Export the table of counts as csv, keeping the topic codes as the index column.
df_numbers.to_csv("ML2_HV_v1_NYT_human_validation_numbers_r1.csv", index=True)

# Drop everything Spark has cached.
spark.catalog.clearCache()
print("cache cleared")

#################################################
# prepare df_original to add tables to it
#################################################

# Load the round-1 training rows together with the rows that were not
# classified, then reshape in three steps:
#   1. keep the existing label under the name mtc_original,
#   2. add a fresh majortopic column (a copy for now), because the tables added
#      back in later may carry new labels,
#   3. add the new train id column train_r2: 1 where sim != 2, otherwise 0.
_remaining_path = "hdfs://192.168.0.182:9000/input/ML2_HV_v1_NYT_r1_train_and_remaining_NOTclassified.parquet"
df_original = (
    spark.read.parquet(_remaining_path)
    .repartition(50)
    .withColumnRenamed('majortopic', 'mtc_original')
    .withColumn('majortopic', col('mtc_original'))
    .withColumn("train_r2", when(col("sim") != 2, 1).otherwise(0))
)

#################################################
# add df_non_sample_replace back to df_original
#################################################

def _prepare_for_round2(df, label_col, train_flag):
    """Reshape a round-1 result table so it can be appended to df_original.

    The existing majortopic column is renamed to mtc_original, a new
    majortopic column is filled from ``label_col``, a constant train_r2
    column is added and the verdict column is dropped.

    Args:
        df: Spark DataFrame with ``majortopic`` and ``verdict`` columns.
        label_col: name of the column that supplies the new majortopic value,
            resolved after the rename ('mtc_original' or 'verdict').
        train_flag: constant written to the train_r2 column.

    Returns:
        The reshaped Spark DataFrame.
    """
    return (
        df.withColumnRenamed('majortopic', 'mtc_original')
        .withColumn('majortopic', col(label_col))
        .withColumn("train_r2", lit(train_flag))
        .drop('verdict')
    )


# The tables below come from different parquet files than df_original, so they
# are appended with unionByName (Spark >= 2.3): columns are matched by name, and
# a schema mismatch raises instead of silently misaligning columns the way a
# positional union would.

if df_non_sample_all != "empty":
    print("df_non_sample_replace is NOT empty")

    # non-sample rows were not validated: keep the original label, not training data
    df_non_sample_replace = _prepare_for_round2(
        spark.read.parquet("hdfs://192.168.0.182:9000/input/df_non_sample_replace_temp.parquet").repartition(50),
        'mtc_original',
        0,
    )

    # add df_non_sample_replace elements to df_original
    df_original = df_original.unionByName(df_non_sample_replace)

else:
    print("df_non_sample_replace is empty")

#################################################
# add df_correct_replace back to df_original
#################################################

# validated positives: the verdict becomes the label and the rows join the training set
df_correct_replace = _prepare_for_round2(
    spark.read.parquet("hdfs://192.168.0.182:9000/input/df_correct_replace_temp.parquet").repartition(50),
    'verdict',
    1,
)

# add df_correct_replace elements to df_original
df_original = df_original.unionByName(df_correct_replace)

#################################################
# add df_wrong_replace back to df_original
#################################################

# rejected verdicts: keep the original label, not training data
df_wrong_replace = _prepare_for_round2(
    spark.read.parquet("hdfs://192.168.0.182:9000/input/df_wrong_replace_temp.parquet").repartition(50),
    'mtc_original',
    0,
)

# add df_wrong_replace elements to df_original
df_original = df_original.unionByName(df_wrong_replace)

#################################################
# final write operations
#################################################

# Persist the round-2 starting table, then show how many rows fall into each
# train_r2 group.
_round2_path = "hdfs://192.168.0.182:9000/input/ML2_HV_v1_NYT_round2_start.parquet"
df_original.write.parquet(_round2_path, mode="overwrite")

_train_counts = df_original.groupBy("train_r2").count()
_train_counts.show(n=30)

# Drop everything Spark has cached.
spark.catalog.clearCache()
print("cache cleared")

# For debugging: read the table back, drop the bulky text and feature columns,
# convert to pandas and export as csv. df_original is a pandas DataFrame from
# here on.
df_original = (
    spark.read.parquet(_round2_path)
    .repartition(50)
    .drop('text', 'words', 'features', 'raw_features')
    .toPandas()
)
df_original.to_csv("ML2_HV_v1_NYT_round2_starting_table.csv", index=False)

# Shut down the context and the session.
sc.stop()
spark.stop()
